{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformers\n",
    "To understand anything that's going on below, check first the slides / video lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn\n",
    "import torch.nn.functional as f\n",
    "import numpy as np "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "f.softargmax = f.softmax  # fix wrong name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi head attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, d_model, num_heads, p, d_input=None):\n",
    "        super().__init__()\n",
    "        self.num_heads = num_heads\n",
    "        self.d_model = d_model\n",
    "        if d_input is None:\n",
    "            d_xq = d_xk = d_xv = d_model\n",
    "        else:\n",
    "            d_xq, d_xk, d_xv = d_input\n",
    "            \n",
    "        # Make sure that the embedding dimension of model is a multiple of number of heads\n",
    "        assert d_model % self.num_heads == 0\n",
    "\n",
    "        self.d_k = d_model // self.num_heads\n",
    "        \n",
    "        # These are still of dimension d_model. They will be split into number of heads \n",
    "        self.W_q = nn.Linear(d_xq, d_model, bias=False)\n",
    "        self.W_k = nn.Linear(d_xk, d_model, bias=False)\n",
    "        self.W_v = nn.Linear(d_xv, d_model, bias=False)\n",
    "        \n",
    "        # Outputs of all sub-layers need to be of dimension d_model\n",
    "        self.W_h = nn.Linear(d_model, d_model)\n",
    "        \n",
    "    def scaled_dot_product_attention(self, Q, K, V):\n",
    "        batch_size = Q.size(0) \n",
    "        k_length = K.size(-2) \n",
    "        \n",
    "        # Scaling by d_k so that the softargmax doesnt saturate\n",
    "        Q = Q / np.sqrt(self.d_k)                   # (bs, n_heads, q_length, dim_per_head)\n",
    "        scores = torch.matmul(Q, K.transpose(2,3))  # (bs, n_heads, q_length, k_length)\n",
    "        \n",
    "        A = f.softargmax(scores, dim=-1)            # (bs, n_heads, q_length, k_length)\n",
    "        \n",
    "        # Get the weighted average of the values\n",
    "        H = torch.matmul(A, V)                      # (bs, n_heads, q_length, dim_per_head)\n",
    "\n",
    "        return H, A \n",
    "\n",
    "        \n",
    "    def split_heads(self, x, batch_size):\n",
    "        \"\"\"\n",
    "        Split the last dimension into (heads X depth)\n",
    "        Return after transpose to put in shape (batch_size X num_heads X seq_length X d_k)\n",
    "        \"\"\"\n",
    "        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)\n",
    "\n",
    "    def group_heads(self, x, batch_size):\n",
    "        \"\"\"\n",
    "        Combine the heads again to get (batch_size X seq_length X (num_heads times d_k))\n",
    "        \"\"\"\n",
    "        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)\n",
    "    \n",
    "\n",
    "    def forward(self, X_q, X_k, X_v):\n",
    "        batch_size, seq_length, dim = X_q.size()\n",
    "\n",
    "        # After transforming, split into num_heads \n",
    "        Q = self.split_heads(self.W_q(X_q), batch_size)  # (bs, n_heads, q_length, dim_per_head)\n",
    "        K = self.split_heads(self.W_k(X_k), batch_size)  # (bs, n_heads, k_length, dim_per_head)\n",
    "        V = self.split_heads(self.W_v(X_v), batch_size)  # (bs, n_heads, v_length, dim_per_head)\n",
    "        \n",
    "        # Calculate the attention weights for each of the heads\n",
    "        H_cat, A = self.scaled_dot_product_attention(Q, K, V)\n",
    "        \n",
    "        # Put all the heads back together by concat\n",
    "        H_cat = self.group_heads(H_cat, batch_size)    # (bs, q_length, dim)\n",
    "        \n",
    "        # Final linear layer  \n",
    "        H = self.W_h(H_cat)          # (bs, q_length, dim)\n",
    "        \n",
    "        return H, A"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Some sanity checks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_mha = MultiHeadAttention(d_model=512, num_heads=8, p=0)\n",
    "def print_out(Q, K, V):\n",
    "    temp_out, temp_attn = temp_mha.scaled_dot_product_attention(Q, K, V)\n",
    "    print('Attention weights are:', temp_attn.squeeze())\n",
    "    print('Output is:', temp_out.squeeze())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check our self attention works - if the query matches with one of the key values, it should have all the attention focused there, with the value returned as the value at that index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention weights are: tensor([3.7266e-06, 9.9999e-01, 3.7266e-06, 3.7266e-06])\n",
      "Output is: tensor([1.0004e+01, 4.0993e-05, 0.0000e+00])\n"
     ]
    }
   ],
   "source": [
    "test_K = torch.tensor(\n",
    "    [[10, 0, 0],\n",
    "     [ 0,10, 0],\n",
    "     [ 0, 0,10],\n",
    "     [ 0, 0,10]]\n",
    ").float()[None,None]\n",
    "\n",
    "test_V = torch.tensor(\n",
    "    [[   1,0,0],\n",
    "     [  10,0,0],\n",
    "     [ 100,5,0],\n",
    "     [1000,6,0]]\n",
    ").float()[None,None]\n",
    "\n",
    "test_Q = torch.tensor(\n",
    "    [[0, 10, 0]]\n",
    ").float()[None,None]\n",
    "print_out(test_Q, test_K, test_V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! We can see that it focuses on the second key and returns the second value. \n",
    "\n",
    "If we give a query that matches two keys exactly, it should return the averaged value of the two values for those two keys. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention weights are: tensor([1.8633e-06, 1.8633e-06, 5.0000e-01, 5.0000e-01])\n",
      "Output is: tensor([549.9979,   5.5000,   0.0000])\n"
     ]
    }
   ],
   "source": [
    "test_Q = torch.tensor([[0, 0, 10]]).float()  \n",
    "print_out(test_Q, test_K, test_V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that it focuses equally on the third and fourth key and returns the average of their values.\n",
    "\n",
    "Now giving all the queries at the same time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention weights are: tensor([[1.8633e-06, 1.8633e-06, 5.0000e-01, 5.0000e-01],\n",
      "        [3.7266e-06, 9.9999e-01, 3.7266e-06, 3.7266e-06],\n",
      "        [5.0000e-01, 5.0000e-01, 1.8633e-06, 1.8633e-06]])\n",
      "Output is: tensor([[5.5000e+02, 5.5000e+00, 0.0000e+00],\n",
      "        [1.0004e+01, 4.0993e-05, 0.0000e+00],\n",
      "        [5.5020e+00, 2.0497e-05, 0.0000e+00]])\n"
     ]
    }
   ],
   "source": [
    "test_Q = torch.tensor(\n",
    "    [[0, 0, 10], [0, 10, 0], [10, 10, 0]]\n",
    ").float()[None,None]\n",
    "print_out(test_Q, test_K, test_V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1D convolution with `kernel_size = 1`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is basically an MLP with one hidden layer and ReLU activation applied to each and every element in the set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, d_model, hidden_dim, p):\n",
    "        super().__init__()\n",
    "        self.k1convL1 = nn.Linear(d_model,    hidden_dim)\n",
    "        self.k1convL2 = nn.Linear(hidden_dim, d_model)\n",
    "        self.activation = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.k1convL1(x)\n",
    "        x = self.activation(x)\n",
    "        x = self.k1convL2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have all components for our Transformer Encoder block shown below!!!!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderLayer(nn.Module):\n",
    "    def __init__(self, d_model, num_heads, conv_hidden_dim, p=0.1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.mha = MultiHeadAttention(d_model, num_heads, p)\n",
    "        self.cnn = CNN(d_model, conv_hidden_dim, p)\n",
    "\n",
    "        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model, eps=1e-6)\n",
    "        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model, eps=1e-6)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \n",
    "        # Multi-head attention \n",
    "        attn_output, _ = self.mha(x, x, x)  # (batch_size, input_seq_len, d_model)\n",
    "        \n",
    "        # Layer norm after adding the residual connection \n",
    "        out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)\n",
    "        \n",
    "        # Feed forward \n",
    "        cnn_output = self.cnn(out1)  # (batch_size, input_seq_len, d_model)\n",
    "        \n",
    "        #Second layer norm after adding residual connection \n",
    "        out2 = self.layernorm2(out1 + cnn_output)  # (batch_size, input_seq_len, d_model)\n",
    "\n",
    "        return out2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoder \n",
    "#### Blocks of N Encoder Layers + Positional encoding + Input embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Self attention by itself does not have any recurrence or convolutions so to make it sensitive to position we must provide additional positional encodings. These are calculated as follows:\n",
    "\n",
    "\\begin{aligned}\n",
    "E(p, 2i)    &= \\sin(p / 10000^{2i / d}) \\\\\n",
    "E(p, 2i+1) &= \\cos(p / 10000^{2i / d})\n",
    "\\end{aligned}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sinusoidal_embeddings(nb_p, dim, E):\n",
    "    theta = np.array([\n",
    "        [p / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]\n",
    "        for p in range(nb_p)\n",
    "    ])\n",
    "    E[:, 0::2] = torch.FloatTensor(np.sin(theta[:, 0::2]))\n",
    "    E[:, 1::2] = torch.FloatTensor(np.cos(theta[:, 1::2]))\n",
    "    E.requires_grad = False\n",
    "    E = E.to(device)\n",
    "\n",
    "class Embeddings(nn.Module):\n",
    "    def __init__(self, d_model, vocab_size, max_position_embeddings, p):\n",
    "        super().__init__()\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=1)\n",
    "        self.position_embeddings = nn.Embedding(max_position_embeddings, d_model)\n",
    "        create_sinusoidal_embeddings(\n",
    "            nb_p=max_position_embeddings,\n",
    "            dim=d_model,\n",
    "            E=self.position_embeddings.weight\n",
    "        )\n",
    "\n",
    "        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)\n",
    "\n",
    "    def forward(self, input_ids):\n",
    "        seq_length = input_ids.size(1)\n",
    "        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)\n",
    "        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)                      # (bs, max_seq_length)\n",
    "        \n",
    "        # Get word embeddings for each input id\n",
    "        word_embeddings = self.word_embeddings(input_ids)                   # (bs, max_seq_length, dim)\n",
    "        \n",
    "        # Get position embeddings for each position id \n",
    "        position_embeddings = self.position_embeddings(position_ids)        # (bs, max_seq_length, dim)\n",
    "        \n",
    "        # Add them both \n",
    "        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)\n",
    "        \n",
    "        # Layer norm \n",
    "        embeddings = self.LayerNorm(embeddings)             # (bs, max_seq_length, dim)\n",
    "        return embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, num_layers, d_model, num_heads, ff_hidden_dim, input_vocab_size,\n",
    "               maximum_position_encoding, p=0.1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.d_model = d_model\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.embedding = Embeddings(d_model, input_vocab_size,maximum_position_encoding, p)\n",
    "\n",
    "        self.enc_layers = nn.ModuleList()\n",
    "        for _ in range(num_layers):\n",
    "            self.enc_layers.append(EncoderLayer(d_model, num_heads, ff_hidden_dim, p))\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.embedding(x) # Transform to (batch_size, input_seq_length, d_model)\n",
    "\n",
    "        for i in range(self.num_layers):\n",
    "            x = self.enc_layers[i](x)\n",
    "\n",
    "        return x  # (batch_size, input_seq_len, d_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext import legacy\n",
    "from torchtext.legacy import data\n",
    "import torchtext.datasets as datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train :  25000\n",
      "test :  25000\n",
      "train.fields : {'text': <torchtext.data.field.Field object at 0x1287ccfa0>, 'label': <torchtext.data.field.LabelField object at 0x1287ccfd0>}\n"
     ]
    }
   ],
   "source": [
    "max_len = 200\n",
    "text = legacy.data.Field(sequential=True, fix_length=max_len, batch_first=True, lower=True, dtype=torch.long)\n",
    "label = legacy.data.LabelField(sequential=False, dtype=torch.long)\n",
    "ds_train, ds_test = legacy.datasets.IMDB.splits(text, label, root='./')\n",
    "print('train : ', len(ds_train))\n",
    "print('test : ', len(ds_test))\n",
    "print('train.fields :', ds_train.fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train :  22500\n",
      "valid :  2500\n",
      "test :  25000\n"
     ]
    }
   ],
   "source": [
    "ds_train, ds_valid = ds_train.split(0.9)\n",
    "print('train : ', len(ds_train))\n",
    "print('valid : ', len(ds_valid))\n",
    "print('test : ', len(ds_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_words = 50_000\n",
    "text.build_vocab(ds_train, max_size=num_words)\n",
    "label.build_vocab(ds_train)\n",
    "vocab = text.vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 164\n",
    "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n",
    "    (ds_train, ds_valid, ds_test), batch_size=batch_size, sort_key=lambda x: len(x.text), repeat=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerClassifier(nn.Module):\n",
    "    def __init__(self, num_layers, d_model, num_heads, conv_hidden_dim, input_vocab_size, num_answers):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.encoder = Encoder(num_layers, d_model, num_heads, conv_hidden_dim, input_vocab_size,\n",
    "                         maximum_position_encoding=10000)\n",
    "        self.dense = nn.Linear(d_model, num_answers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.encoder(x)\n",
    "        \n",
    "        x, _ = torch.max(x, dim=1)\n",
    "        x = self.dense(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TransformerClassifier(\n",
       "  (encoder): Encoder(\n",
       "    (embedding): Embeddings(\n",
       "      (word_embeddings): Embedding(50002, 32, padding_idx=1)\n",
       "      (position_embeddings): Embedding(10000, 32)\n",
       "      (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)\n",
       "    )\n",
       "    (enc_layers): ModuleList(\n",
       "      (0): EncoderLayer(\n",
       "        (mha): MultiHeadAttention(\n",
       "          (W_q): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_k): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_v): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_h): Linear(in_features=32, out_features=32, bias=True)\n",
       "        )\n",
       "        (cnn): CNN(\n",
       "          (k1convL1): Linear(in_features=32, out_features=128, bias=True)\n",
       "          (k1convL2): Linear(in_features=128, out_features=32, bias=True)\n",
       "          (activation): ReLU()\n",
       "        )\n",
       "        (layernorm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)\n",
       "        (layernorm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)\n",
       "      )\n",
       "      (1): EncoderLayer(\n",
       "        (mha): MultiHeadAttention(\n",
       "          (W_q): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_k): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_v): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_h): Linear(in_features=32, out_features=32, bias=True)\n",
       "        )\n",
       "        (cnn): CNN(\n",
       "          (k1convL1): Linear(in_features=32, out_features=128, bias=True)\n",
       "          (k1convL2): Linear(in_features=128, out_features=32, bias=True)\n",
       "          (activation): ReLU()\n",
       "        )\n",
       "        (layernorm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)\n",
       "        (layernorm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)\n",
       "      )\n",
       "      (2): EncoderLayer(\n",
       "        (mha): MultiHeadAttention(\n",
       "          (W_q): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_k): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_v): Linear(in_features=32, out_features=32, bias=False)\n",
       "          (W_h): Linear(in_features=32, out_features=32, bias=True)\n",
       "        )\n",
       "        (cnn): CNN(\n",
       "          (k1convL1): Linear(in_features=32, out_features=128, bias=True)\n",
       "          (k1convL2): Linear(in_features=128, out_features=32, bias=True)\n",
       "          (activation): ReLU()\n",
       "        )\n",
       "        (layernorm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)\n",
       "        (layernorm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (dense): Linear(in_features=32, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = TransformerClassifier(num_layers=1, d_model=32, num_heads=2, \n",
    "                         conv_hidden_dim=128, input_vocab_size=50002, num_answers=2)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)\n",
    "epochs = 10\n",
    "t_total = len(train_loader) * epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(train_loader, valid_loader):\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        train_iterator, valid_iterator = iter(train_loader), iter(valid_loader)\n",
    "        nb_batches_train = len(train_loader)\n",
    "        train_acc = 0\n",
    "        model.train()\n",
    "        losses = 0.0\n",
    "\n",
    "        for batch in train_iterator:\n",
    "            x = batch.text.to(device)\n",
    "            y = batch.label.to(device)\n",
    "            \n",
    "            out = model(x)  # ①\n",
    "\n",
    "            loss = f.cross_entropy(out, y)  # ②\n",
    "            \n",
    "            model.zero_grad()  # ③\n",
    "\n",
    "            loss.backward()  # ④\n",
    "            losses += loss.item()\n",
    "\n",
    "            optimizer.step()  # ⑤\n",
    "                        \n",
    "            train_acc += (out.argmax(1) == y).cpu().numpy().mean()\n",
    "        \n",
    "        print(f\"Training loss at epoch {epoch} is {losses / nb_batches_train}\")\n",
    "        print(f\"Training accuracy: {train_acc / nb_batches_train}\")\n",
    "        print('Evaluating on validation:')\n",
    "        evaluate(valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(data_loader):\n",
    "    data_iterator = iter(data_loader)\n",
    "    nb_batches = len(data_loader)\n",
    "    model.eval()\n",
    "    acc = 0 \n",
    "    for batch in data_iterator:\n",
    "        x = batch.text.to(device)\n",
    "        y = batch.label.to(device)\n",
    "                \n",
    "        out = model(x)\n",
    "        acc += (out.argmax(1) == y).cpu().numpy().mean()\n",
    "\n",
    "    print(f\"Eval accuracy: {acc / nb_batches}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss at epoch 0 is 0.7040774178330915\n",
      "Training accuracy: 0.563751780309774\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.6621138211382114\n",
      "Training loss at epoch 1 is 0.6331876960113971\n",
      "Training accuracy: 0.6726010325796692\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.7393902439024391\n",
      "Training loss at epoch 2 is 0.5574538331397259\n",
      "Training accuracy: 0.7341385526081536\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.791910569105691\n",
      "Training loss at epoch 3 is 0.4765370425081601\n",
      "Training accuracy: 0.7838870838525904\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8216666666666668\n",
      "Training loss at epoch 4 is 0.4052363144655297\n",
      "Training accuracy: 0.8277216485668509\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8325609756097563\n",
      "Training loss at epoch 5 is 0.3416111364851903\n",
      "Training accuracy: 0.8608354103614029\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8455284552845529\n",
      "Training loss at epoch 6 is 0.2870733802118441\n",
      "Training accuracy: 0.888063022965996\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8582113821138211\n",
      "Training loss at epoch 7 is 0.2374339928157138\n",
      "Training accuracy: 0.9137384279864698\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8613821138211383\n",
      "Training loss at epoch 8 is 0.19403393457840828\n",
      "Training accuracy: 0.9343733309595865\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8569105691056911\n",
      "Training loss at epoch 9 is 0.15626875180615127\n",
      "Training accuracy: 0.9517257877870745\n",
      "Evaluating on validation:\n",
      "Eval accuracy: 0.8609349593495935\n"
     ]
    }
   ],
   "source": [
    "train(train_loader, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval accuracy: 0.8047465589787477\n"
     ]
    }
   ],
   "source": [
    "evaluate(test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
